/* ==================================================================   
 * Created [2009-4-27 下午11:32:55] by Jon.King 
 * ==================================================================  
 * TSS 
 * ================================================================== 
 * mailTo:jinpujun@hotmail.com
 * Copyright (c) Jon.King, 2009-2012 
 * ================================================================== 
*/

package com.jinhe.tss.core.web.wrapper;

import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

import com.jinhe.tss.core.web.RewriteableHttpServletRequest;

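/**
 * <p>
 * RewriteableHttpServletRequestWrapper.java
 * 代理请求对象（可自定义或修改请求参数）
 * Request wrapper that lets callers add or override request parameters,
 * headers, cookies and the servlet path on top of the wrapped request.
 * </p>
 *
 */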
public class RewriteableHttpServletRequestWrapper extends HttpServletRequestWrapper 
				implements RewriteableHttpServletRequest {

	/**
	 * Parameters added via {@link #addParameter}. For a given name these values
	 * replace (rather than extend) the wrapped request's values.
	 */
	private final Map<String, String[]> params = new HashMap<String, String[]>();

	/**
	 * Headers set via {@link #setHeader}. Keyed case-insensitively because HTTP
	 * header names are case-insensitive, and lookups such as {@code getHeader}
	 * must find an override regardless of the caller's spelling.
	 */
	private final Map<String, String[]> headers = new TreeMap<String, String[]>(String.CASE_INSENSITIVE_ORDER);

	/** Cookies added via {@link #addCookie}, keyed by name; they replace same-named original cookies. */
	private final Map<String, Cookie> cookies = new HashMap<String, Cookie>();

	/** Overriding servlet path, or {@code null} to delegate to the wrapped request. */
	private String servletPath = null;

	/**
	 * Returns a request whose parameters, headers, cookies and servlet path can be rewritten.
	 * If {@code request} already implements {@link RewriteableHttpServletRequest} it is
	 * returned unchanged (so repeated calls do not stack wrappers); otherwise it is wrapped.
	 *
	 * @param request the request to enhance
	 * @return a rewriteable view of {@code request}
	 */
	public static RewriteableHttpServletRequest getRewriteableHttpServletRequest(HttpServletRequest request) {
		if (request instanceof RewriteableHttpServletRequest) {
			return (RewriteableHttpServletRequest) request;
		}
		return new RewriteableHttpServletRequestWrapper(request);
	}

	private RewriteableHttpServletRequestWrapper(HttpServletRequest request) {
		super(request);
	}

	/**
	 * Appends {@code value} to the values previously added for {@code name} through this
	 * method. Only values added here are combined; the wrapped request's own values for
	 * {@code name} are not included and are hidden once any value has been added.
	 *
	 * @param name  parameter name
	 * @param value value to append (placed last)
	 */
	public void addParameter(String name, String value) {
		String[] oldValues = params.get(name);
		String[] newValues;
		if (oldValues == null) {
			newValues = new String[] { value };
		} else {
			newValues = Arrays.copyOf(oldValues, oldValues.length + 1);
			newValues[oldValues.length] = value;
		}
		params.put(name, newValues);
	}

	/**
	 * Returns a fresh map containing the wrapped request's parameters, with any
	 * parameters added via {@link #addParameter} replacing same-named entries.
	 *
	 * @return a new, caller-owned parameter map
	 */
	@SuppressWarnings("unchecked")
	public Map<String, String[]> getParameterMap() {
		Map<String, String[]> map = new HashMap<String, String[]>();
		// The wrapped API returns a raw Map (hence the unchecked suppression).
		map.putAll(super.getParameterMap());
		map.putAll(params);
		return map;
	}

	/**
	 * Returns the names of all parameters: the wrapped request's plus any added ones.
	 *
	 * @return enumeration of parameter names
	 */
	public Enumeration<String> getParameterNames() {
		return Collections.enumeration(getParameterMap().keySet());
	}

	/**
	 * Returns all values of the named parameter (multi-valued). Added values take
	 * precedence; otherwise the wrapped request is consulted.
	 *
	 * @param name parameter name
	 * @return the values, or {@code null} if the parameter does not exist
	 */
	public String[] getParameterValues(String name) {
		String[] values = params.get(name);
		if (values != null) {
			return values;
		}
		return super.getParameterValues(name);
	}

	/**
	 * Returns the first value of the named parameter (single-valued). Added values
	 * take precedence; otherwise the wrapped request is consulted.
	 *
	 * @param name parameter name
	 * @return the first value, or whatever the wrapped request returns
	 */
	public String getParameter(String name) {
		String[] values = params.get(name);
		if (values != null && values.length > 0) {
			return values[0];
		}
		return super.getParameter(name);
	}

	/**
	 * Returns the overriding value of the named header if one was set (matched
	 * case-insensitively), otherwise delegates to the wrapped request.
	 *
	 * @param name header name
	 * @return the header value
	 */
	public String getHeader(String name) {
		String[] value = headers.get(name);
		if (value != null && value.length > 0) {
			return value[0];
		}
		return super.getHeader(name);
	}

	/**
	 * Returns the union of the wrapped request's header names and the overridden
	 * header names, de-duplicated case-insensitively. The first spelling seen wins:
	 * the wrapped request's own spelling is kept for headers it already has.
	 *
	 * @return enumeration of header names (never {@code null})
	 */
	@SuppressWarnings("unchecked")
	public Enumeration<String> getHeaderNames() {
		Set<String> names = new TreeSet<String>(String.CASE_INSENSITIVE_ORDER);

		// The container may return null when it does not expose headers.
		Enumeration<String> original = super.getHeaderNames();
		if (original != null) {
			while (original.hasMoreElements()) {
				names.add(original.nextElement());
			}
		}

		names.addAll(headers.keySet());
		return Collections.enumeration(names);
	}

	/**
	 * Returns the values of the named header. If the header was overridden via
	 * {@link #setHeader}, only the override is returned; otherwise the wrapped
	 * request's values are returned unchanged (order and duplicates preserved).
	 *
	 * @param name header name (case-insensitive)
	 * @return enumeration of values, or whatever the wrapped request returns
	 */
	@SuppressWarnings("unchecked")
	public Enumeration<String> getHeaders(String name) {
		String[] value = headers.get(name);
		if (value != null && value.length > 0) {
			return Collections.enumeration(Arrays.asList(value));
		}
		return super.getHeaders(name);
	}

	/**
	 * Overrides the named header with a single value, replacing any previous override.
	 *
	 * @param name  header name (case-insensitive)
	 * @param value the new value
	 */
	public void setHeader(String name, String value) {
		headers.put(name, new String[] { value });
	}

	/**
	 * Overrides the servlet path; pass {@code null} to fall back to the wrapped request.
	 *
	 * @param servletPath the servlet path to report
	 */
	public void setServletPath(String servletPath) {
		this.servletPath = servletPath;
	}

	/**
	 * Returns the overriding servlet path if set, otherwise the wrapped request's.
	 *
	 * @return the servlet path
	 */
	public String getServletPath() {
		if (this.servletPath != null) {
			return this.servletPath;
		}
		return super.getServletPath();
	}

	/**
	 * Returns the wrapped request's cookies merged with cookies added via {@link #addCookie}.
	 * When the wrapped request has several cookies with the same name, the first one in
	 * its array wins. The original design note says parent-path cookies appear before
	 * same-named child-path cookies, so the parent-path cookie is preferred. Added cookies
	 * replace same-named originals.
	 *
	 * @return the merged cookies, or {@code null} if there are none
	 * @see javax.servlet.http.HttpServletRequestWrapper#getCookies()
	 */
	public Cookie[] getCookies() {
		// LinkedHashMap keeps a deterministic order: original order first, then new additions.
		Map<String, Cookie> merged = new LinkedHashMap<String, Cookie>();

		Cookie[] originals = super.getCookies();
		if (originals != null) {
			for (Cookie cookie : originals) {
				if (!merged.containsKey(cookie.getName())) {
					merged.put(cookie.getName(), cookie);
				}
			}
		}

		merged.putAll(cookies);

		if (merged.isEmpty()) {
			return null;
		}
		return merged.values().toArray(new Cookie[merged.size()]);
	}

	/**
	 * Adds (or replaces, by name) a cookie visible through {@link #getCookies()}.
	 *
	 * @param cookie the cookie to add
	 */
	public void addCookie(Cookie cookie) {
		cookies.put(cookie.getName(), cookie);
	}
}
